Summary

Key questions

  1. how to visualize their importance or generate hypotheses from gene clusters

TODO (before submission)

  1. functionalize and reproduce
  2. make prettier binning plots
In [1]:
%matplotlib inline

import datetime
import glob
import os
import pickle
import platform
import time

import graphtools as gt
import matplotlib as mpl
import matplotlib.pyplot as plt
import meld
import numpy as np
import pandas as pd
import phate
import scanpy as sc
import scprep
import seaborn as sns
from matplotlib import animation
from mpl_toolkits.mplot3d import Axes3D
from py_pcha import PCHA
from sklearn.decomposition import PCA



# settings
# Global plotting defaults for every figure produced below.
plt.rc('font', size=9)
# The generic family name is hyphenated; 'sans serif' is not a recognised family.
plt.rc('font', family='sans-serif')
# Font type 42 keeps text as TrueType in PDF/PS output (editable in vector tools).
plt.rcParams['pdf.fonttype'] = 42
plt.rcParams['ps.fonttype'] = 42
plt.rcParams['text.usetex'] = False
plt.rcParams['legend.frameon'] = False
plt.rcParams['axes.grid'] = False
plt.rcParams['legend.markerscale'] = 0.5
sc.set_figure_params(dpi=300, dpi_save=600,
                     frameon=False,
                     fontsize=9)
plt.rcParams['savefig.dpi'] = 600
sc.settings.verbosity = 2
# Configure the live settings object; assigning on the private ScanpyConfig class
# would overwrite the class attribute itself rather than set a value.
sc.settings.n_jobs = -1
sns.set_style("ticks")
In [26]:
# NOTE(review): this cell carries execution count 26 but sits above the cell (In [2])
# that assigns `dremi_data_files`; in a fresh top-to-bottom run it executes before
# that assignment.
dremi_data_files
Out[26]:
['/home/ngr4/project/sccovid/results/gdata_horizconcat.csv',
 '/home/ngr4/project/sccovid/results/gdata.csv']
In [2]:
# Directory that holds the result CSVs read below.
pfp = '/home/ngr4/project/sccovid/results/'

# Match 'gdata' against the file name only (a 'gdata' substring elsewhere in the
# path must not pull in unrelated CSVs) and sort so the listing is stable.
dremi_data_files = sorted(glob.glob(os.path.join(pfp, '*gdata*.csv')))
if not dremi_data_files:
    # Fail here with a clear message instead of a KeyError in a later cell.
    raise FileNotFoundError('no gdata CSV files found in {}'.format(pfp))

# One DataFrame per file, keyed by the file name up to '.csv'
# (e.g. 'gdata', 'gdata_horizconcat'); the first CSV column becomes the index.
data = {os.path.basename(f).split('.csv')[0]: pd.read_csv(f, index_col=0) for f in dremi_data_files}
In [3]:
# Index both loaded tables by their 'gene' column, then release the loader dict
# so only the two named frames stay in the namespace.
dremis, dremis_sep = (
    data[table].set_index('gene') for table in ('gdata_horizconcat', 'gdata')
)
del data

# Show the horizontally concatenated table.
dremis
Out[3]:
0 1 2 3 4 5 6 7 8 9 ... 790 791 792 793 794 795 796 797 798 799
gene
AL627309.1 0.000469 0.000000 0.000000 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
AL627309.3 0.000427 0.000000 0.000000 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
AL627309.4 0.000430 0.000000 0.000000 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
AL732372.1 0.000427 0.000000 0.000000 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
AL669831.2 0.000425 0.000000 0.000000 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
AC007325.2 0.000263 0.001748 0.000000 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
AL354822.1 0.001188 0.000000 0.000000 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
AC004556.1 0.000425 0.000000 0.000000 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
AC240274.1 0.000964 0.000000 0.001894 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
scv2_orf1-10 0.000974 0.000000 0.000000 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0

24714 rows × 800 columns

In [63]:
n_AT = 8           # number of components (noc) requested from PCHA
aa_on_pca = False  # True: run PCHA on the PCA scores instead of on the data itself

# Fit PCA / PHATE only when the kernel does not hold them yet: the cell then runs
# on a fresh kernel and skips the refit when it is executed again.
calc_pcaphate = not all(
    name in globals() for name in ('pca', 'dremis_pca', 'phate_op', 'dremis_phate')
)

if calc_pcaphate:
    pca = PCA(n_components=100).fit(dremis)
    dremis_pca = pca.transform(dremis)

    # plot on phate space
    phate_op = phate.PHATE(n_components=3).fit(dremis)
    dremis_phate = phate_op.transform(dremis)

# try PCHA
X = np.array(dremis_pca) if aa_on_pca else np.array(dremis)
start = time.time()
XC, S, C, SSE, varexpl = PCHA(X.T, noc=n_AT)  # S for each cell sum to 1
print('AA on sample in {:.2f}-min'.format((time.time()-start)/60))

# transform ATs
# np.asarray guards against PCHA handing back a numpy matrix (TODO confirm), so
# the column slices used by the plotting cells stay 1-D. One row per archetype.
archetypes = np.asarray(XC).T
if aa_on_pca:
    # PCHA on PCA data: rows are already PCA scores; map them back to data space
    # before PHATE, because phate_op was fit on `dremis`.
    Y_pca = archetypes
    Y_phate = phate_op.transform(pca.inverse_transform(archetypes))
else:
    # PCHA in data-space
    Y_pca = pca.transform(archetypes)
    # NOTE: the recorded output shows PHATE's RuntimeWarning about transforming
    # points the operator was not fit on; that is expected here.
    Y_phate = phate_op.transform(archetypes)


fig, ax = plt.subplots(1,2,figsize=(6,2))
scprep.plot.scatter2d(dremis_pca,
                      ticks=None,
                      c='#f7c09c',
                      label_prefix='PCA',
                      ax=ax[0])
scprep.plot.scatter2d(dremis_phate,
                      ticks=None,
                      c='#f7c09c',
                      label_prefix='PHATE',
                      ax=ax[1])
AA on sample in 3.54-min
Calculating KNN search...
Calculated KNN search in 0.06 seconds.
Calculating affinities...
/gpfs/ycga/project/dijk/ngr4/conda_envs/rnavel/lib/python3.7/site-packages/phate/phate.py:787: RuntimeWarning: Pre-fit PHATE should not be used to transform a new data matrix. Please fit PHATE to the new data by running 'fit' with the new data.
  RuntimeWarning)
Out[63]:
<matplotlib.axes._subplots.AxesSubplot at 0x2b11ad873650>
In [64]:
# Running total of the variance ratio over the fitted principal components.
cum_var = np.cumsum(pca.explained_variance_ratio_)
# Number the components from 1 so the x axis reads "first k PCs".
p = sns.scatterplot(x=np.arange(1, cum_var.shape[0] + 1), y=cum_var)
# The plotted quantity is cumulative, so label it as such.
p.set_ylabel('Cumulative variance explained')
p.set_xlabel('PC')
p.set_title('Dimensionality of data')
Out[64]:
Text(0.5, 1.0, 'Dimensionality of data')
In [65]:
# plot on PCA space
fig = plt.figure(figsize=(3, 2))
# background cloud: every row of the PCA scores, first two components
plt.scatter(dremis_pca[:,0], dremis_pca[:,1], c='#f7c09c', alpha=0.5, s=3)
# large dark markers for the PCA-transformed PCHA output
plt.scatter([Y_pca[:,0]], [Y_pca[:,1]], c='#616066', s=200)
plt.xticks([])
plt.yticks([])
# number the markers 1..n, centred on each marker
label_style = dict(horizontalalignment='center', verticalalignment='center',
                   fontdict={'color': 'white','size':10,'weight':'bold'})
for i in range(Y_pca.shape[0]):
    plt.text(Y_pca[i,0], Y_pca[i,1], i+1, **label_style)
In [66]:
# plot on phate space
fig = plt.figure(figsize=(3, 2))
# background cloud: every row of the PHATE embedding, first two coordinates
plt.scatter(dremis_phate[:,0], dremis_phate[:,1], c='#f7c09c', alpha=0.5, s=3, lw=0)
# large dark markers for the PHATE-transformed PCHA output
plt.scatter([Y_phate[:,0]], [Y_phate[:,1]], c='#616066', s=200)
plt.xticks([])
plt.yticks([])
# number the markers 1..n, centred on each marker
label_style = dict(horizontalalignment='center', verticalalignment='center',
                   fontdict={'color': 'white','size':10,'weight':'bold'})
for i in range(Y_phate.shape[0]):
    plt.text(Y_phate[i,0], Y_phate[i,1], i+1, **label_style)
In [69]:
  
# 3d plot
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# coordinate triples for the data cloud and for the transformed PCHA output
cloud_xyz = (dremis_phate[:,0], dremis_phate[:,1], dremis_phate[:,2])
marker_xyz = (Y_phate[:,0], Y_phate[:,1], Y_phate[:,2])
ax.scatter(*cloud_xyz, c='skyblue', s=1, alpha=0.2)
ax.scatter(*marker_xyz, s=200, c='#616066')

# number the markers 1..n, centred on each marker
label_style = dict(horizontalalignment='center', verticalalignment='center',
                   fontdict={'color': 'white','size':10,'weight':'bold'})
for i in range(Y_phate.shape[0]):
    ax.text(Y_phate[i,0], Y_phate[i,1], Y_phate[i,2], i+1, **label_style)
# ax.view_init(30, 185)

# strip the tick marks on all three axes
for clear_ticks in (ax.set_xticks, ax.set_yticks, ax.set_zticks):
    clear_ticks([])
ax.set_xlabel('PHATE1')
ax.set_ylabel('PHATE2')
ax.set_zlabel('PHATE3')
Out[69]:
Text(0.5, 0, 'PHATE3')
In [70]:
# try different ats

n_AT = 6           # number of components (noc) requested from PCHA
aa_on_pca = False  # True: run PCHA on the PCA scores instead of on the data itself

# Fit PCA / PHATE only when the kernel does not hold them yet: the cell then runs
# on a fresh kernel and skips the refit when it is executed again.
calc_pcaphate = not all(
    name in globals() for name in ('pca', 'dremis_pca', 'phate_op', 'dremis_phate')
)

if calc_pcaphate:
    pca = PCA(n_components=100).fit(dremis)
    dremis_pca = pca.transform(dremis)

    # plot on phate space
    phate_op = phate.PHATE(n_components=3).fit(dremis)
    dremis_phate = phate_op.transform(dremis)

# try PCHA
X = np.array(dremis_pca) if aa_on_pca else np.array(dremis)
start = time.time()
XC, S, C, SSE, varexpl = PCHA(X.T, noc=n_AT)  # S for each cell sum to 1
print('AA on sample in {:.2f}-min'.format((time.time()-start)/60))

# transform ATs
# np.asarray guards against PCHA handing back a numpy matrix (TODO confirm), so
# the column slices used below stay 1-D. One row per archetype.
archetypes = np.asarray(XC).T
if aa_on_pca:
    # PCHA on PCA data: rows are already PCA scores; map them back to data space
    # before PHATE, because phate_op was fit on `dremis`.
    Y_pca = archetypes
    Y_phate = phate_op.transform(pca.inverse_transform(archetypes))
else:
    # PCHA in data-space
    Y_pca = pca.transform(archetypes)
    # NOTE: the recorded output shows PHATE's RuntimeWarning about transforming
    # points the operator was not fit on; that is expected here.
    Y_phate = phate_op.transform(archetypes)


fig, ax = plt.subplots(1,2,figsize=(6,2))
scprep.plot.scatter2d(dremis_pca,
                      ticks=None,
                      c='#f7c09c',
                      label_prefix='PCA',
                      ax=ax[0])
scprep.plot.scatter2d(dremis_phate,
                      ticks=None,
                      c='#f7c09c',
                      label_prefix='PHATE',
                      ax=ax[1])

# text style shared by the numbered archetype labels in both plots below
label_style = dict(horizontalalignment='center', verticalalignment='center',
                   fontdict={'color': 'white','size':10,'weight':'bold'})

# plot on phate space
fig = plt.figure(figsize=(3, 2))
plt.scatter(dremis_phate[:,0], dremis_phate[:,1], s=3, alpha=0.5, c='#f7c09c', lw=0)
plt.scatter([Y_phate[:,0]], [Y_phate[:,1]], s=200, c='#616066')
plt.xticks([])
plt.yticks([])
for i in range(Y_phate.shape[0]):
    plt.text(Y_phate[i,0], Y_phate[i,1], i+1, **label_style)


# 3d plot
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(dremis_phate[:,0], dremis_phate[:,1], dremis_phate[:,2], c='skyblue', s=1, alpha=0.2)
ax.scatter(Y_phate[:,0], Y_phate[:,1], Y_phate[:,2], s=200, c='#616066')
for i in range(Y_phate.shape[0]):
    ax.text(Y_phate[i,0], Y_phate[i,1], Y_phate[i,2], i+1, **label_style)
# ax.view_init(30, 185)
ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])
ax.set_xlabel('PHATE1')
ax.set_ylabel('PHATE2')
ax.set_zlabel('PHATE3')
AA on sample in 5.21-min
Calculating KNN search...
Calculated KNN search in 0.04 seconds.
Calculating affinities...
/gpfs/ycga/project/dijk/ngr4/conda_envs/rnavel/lib/python3.7/site-packages/phate/phate.py:787: RuntimeWarning: Pre-fit PHATE should not be used to transform a new data matrix. Please fit PHATE to the new data by running 'fit' with the new data.
  RuntimeWarning)
Out[70]:
Text(0.5, 0, 'PHATE3')
In [71]:
# try different ats

n_AT = 12          # number of components (noc) requested from PCHA
aa_on_pca = False  # True: run PCHA on the PCA scores instead of on the data itself

# Fit PCA / PHATE only when the kernel does not hold them yet: the cell then runs
# on a fresh kernel and skips the refit when it is executed again.
calc_pcaphate = not all(
    name in globals() for name in ('pca', 'dremis_pca', 'phate_op', 'dremis_phate')
)

if calc_pcaphate:
    pca = PCA(n_components=100).fit(dremis)
    dremis_pca = pca.transform(dremis)

    # plot on phate space
    phate_op = phate.PHATE(n_components=3).fit(dremis)
    dremis_phate = phate_op.transform(dremis)

# try PCHA
X = np.array(dremis_pca) if aa_on_pca else np.array(dremis)
start = time.time()
XC, S, C, SSE, varexpl = PCHA(X.T, noc=n_AT)  # S for each cell sum to 1
print('AA on sample in {:.2f}-min'.format((time.time()-start)/60))

# transform ATs
# np.asarray guards against PCHA handing back a numpy matrix (TODO confirm), so
# the column slices used below stay 1-D. One row per archetype.
archetypes = np.asarray(XC).T
if aa_on_pca:
    # PCHA on PCA data: rows are already PCA scores; map them back to data space
    # before PHATE, because phate_op was fit on `dremis`.
    Y_pca = archetypes
    Y_phate = phate_op.transform(pca.inverse_transform(archetypes))
else:
    # PCHA in data-space
    Y_pca = pca.transform(archetypes)
    # NOTE: the recorded output shows PHATE's RuntimeWarning about transforming
    # points the operator was not fit on; that is expected here.
    Y_phate = phate_op.transform(archetypes)


# text style shared by the numbered archetype labels in both plots below
label_style = dict(horizontalalignment='center', verticalalignment='center',
                   fontdict={'color': 'white','size':10,'weight':'bold'})

# plot on phate space
fig = plt.figure(figsize=(3, 2))
plt.scatter(dremis_phate[:,0], dremis_phate[:,1], s=3, alpha=0.5, c='#f7c09c', lw=0)
plt.scatter([Y_phate[:,0]], [Y_phate[:,1]], s=200, c='#616066')
plt.xticks([])
plt.yticks([])
for i in range(Y_phate.shape[0]):
    plt.text(Y_phate[i,0], Y_phate[i,1], i+1, **label_style)


# 3d plot
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(dremis_phate[:,0], dremis_phate[:,1], dremis_phate[:,2], c='skyblue', s=1, alpha=0.2)
ax.scatter(Y_phate[:,0], Y_phate[:,1], Y_phate[:,2], s=200, c='#616066')
for i in range(Y_phate.shape[0]):
    ax.text(Y_phate[i,0], Y_phate[i,1], Y_phate[i,2], i+1, **label_style)
# ax.view_init(30, 185)
ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])
ax.set_xlabel('PHATE1')
ax.set_ylabel('PHATE2')
ax.set_zlabel('PHATE3')
AA on sample in 3.05-min
Calculating KNN search...
Calculated KNN search in 0.09 seconds.
Calculating affinities...
/gpfs/ycga/project/dijk/ngr4/conda_envs/rnavel/lib/python3.7/site-packages/phate/phate.py:787: RuntimeWarning: Pre-fit PHATE should not be used to transform a new data matrix. Please fit PHATE to the new data by running 'fit' with the new data.
  RuntimeWarning)
Out[71]:
Text(0.5, 0, 'PHATE3')
In [87]:
# scprep utils, REF: https://github.com/KrishnaswamyLab/scprep/blob/0d4843f35702d0d8778e7a6b843f8a4262048d60/scprep/plot/utils.py#L70
def _in_ipynb():
    """Check if we are running in a Jupyter Notebook
    Credit to https://stackoverflow.com/a/24937408/3996580
    """
    __VALID_NOTEBOOKS = [
        "<class 'google.colab._shell.Shell'>",
        "<class 'ipykernel.zmqshell.ZMQInteractiveShell'>",
    ]
    try:
        return str(type(get_ipython())) in __VALID_NOTEBOOKS
    except NameError:
        return False
    
def _get_figure(ax=None, figsize=None, subplot_kw=None):
    if subplot_kw is None:
        subplot_kw = {}
    if ax is None:
        if "projection" in subplot_kw and subplot_kw["projection"] == "3d":
            # ensure mplot3d is loaded
            Axes3D
        fig, ax = plt.subplots(figsize=figsize, subplot_kw=subplot_kw)
        show_fig = True
    else:
        try:
            fig = ax.get_figure()
        except AttributeError as e:
            if not isinstance(ax, mpl.axes.Axes):
                raise TypeError(
                    "Expected ax as a matplotlib.axes.Axes. " "Got {}".format(type(ax))
                )
            else:
                raise e
        if "projection" in subplot_kw:
            if subplot_kw["projection"] == "3d" and not isinstance(
                ax, Axes3D
            ):
                raise TypeError(
                    "Expected ax with projection='3d'. " "Got 2D axis instead."
                )
        show_fig = False
    return fig, ax, show_fig

def show(fig):
    """Show a matplotlib Figure correctly, regardless of platform
    If running a Jupyter notebook, we avoid running `fig.show`. If running
    in Windows, it is necessary to run `plt.show` rather than `fig.show`.
    Parameters
    ----------
    fig : matplotlib.Figure
        Figure to show
    """
    fig.tight_layout()
    if _mpl_is_gui_backend():
        if platform.system() == "Windows":
            plt.show(block=True)
        else:
            fig.show()
In [102]:
# rotate 3d
# Settings for the rotating view of the 3-D PHATE plot.
filename=None            # output path ending in .gif or .mp4; None skips saving
dpi=300                  # resolution passed to ani.save
rotation_speed=30        # degrees of azimuth per second
fps=1                    # frames per second
elev=None                # elevation forwarded to ax.view_init on every frame
figsize=None             # forwarded to _get_figure when a new figure is created
ipython_html="jshtml"    # value for matplotlib's animation html setting in notebooks
ax=None                  # existing 3-D axis to draw on; None creates a new figure

if rotation_speed <= 0 or fps <= 0:
    # A non-positive value would give a zero or negative frame count below.
    raise ValueError(
        "rotation_speed and fps must be positive. Got {} and {}".format(
            rotation_speed, fps
        )
    )

if _in_ipynb():
    # in ipynb
    # credit to
    # http://tiao.io/posts/notebooks/save-matplotlib-animations-as-gifs/
    mpl.rc("animation", html=ipython_html)

if filename is not None:
    # Pick the writer from the file extension before any drawing work is done.
    if filename.endswith(".gif"):
        writer = "imagemagick"
    elif filename.endswith(".mp4"):
        writer = "ffmpeg"
    else:
        raise ValueError(
            "filename must end in .gif or .mp4. Got {}".format(filename)
        )

degrees_per_frame = rotation_speed / fps
# At least one frame, so the division below can never hit zero.
frames = max(1, int(round(360 / degrees_per_frame)))

# fix rounding errors
degrees_per_frame = 360 / frames
interval = 1000 * degrees_per_frame / rotation_speed  # milliseconds between frames

fig, ax, show_fig = _get_figure(ax, figsize, subplot_kw={"projection": "3d"})
ax.scatter(dremis_phate[:,0], dremis_phate[:,1], dremis_phate[:,2], c='skyblue', s=1, alpha=0.2)
ax.scatter(Y_phate[:,0], Y_phate[:,1], Y_phate[:,2], s=200, c='#616066')
for i in range(Y_phate.shape[0]):
    ax.text(Y_phate[i,0], Y_phate[i,1], Y_phate[i,2], i+1, horizontalalignment='center', verticalalignment='center', fontdict={'color': 'white','size':10,'weight':'bold'})
# ax.view_init(30, 185)
ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])
ax.set_xlabel('PHATE1')
ax.set_ylabel('PHATE2')
ax.set_zlabel('PHATE3')

# Starting azimuth; each frame rotates relative to it.
azim = ax.azim

def init():
    """Initial frame: leave the axis as drawn."""
    return ax

def animate(i):
    """Rotate the view to frame `i`, keeping the configured elevation."""
    ax.view_init(azim=azim + i * degrees_per_frame, elev=elev)
    return ax

# Use the explicitly imported submodule rather than relying on `mpl.animation`
# having been loaded as a side effect of another import.
ani = animation.FuncAnimation(
    fig,
    animate,
    init_func=init,
    frames=range(frames),
    interval=interval,
    blit=False,
)

if filename is not None:
    ani.save(filename, writer=writer, dpi=dpi)

if _in_ipynb():
    # credit to https://stackoverflow.com/a/45573903/3996580
    # Close the static figure so only the animation is rendered.
    plt.close(fig)
elif show_fig:
    show(fig)

ani
In [103]:
# Display the animation object again; the previous cell already ends with the same expression.
ani
Out[103]:
In [95]:
# Check what the notebook-detection helper reports in this kernel.
# NOTE(review): execution count 95 is lower than the cells above it, so this ran earlier.
_in_ipynb()
Out[95]:
True
In [ ]: